/** * * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.openejb.client; import org.w3c.dom.Element; import javax.jws.WebService; import javax.xml.bind.JAXBContext; import javax.xml.namespace.QName; import javax.xml.transform.Source; import javax.xml.ws.BindingProvider; import javax.xml.ws.Dispatch; import javax.xml.ws.Endpoint; import javax.xml.ws.EndpointReference; import javax.xml.ws.Service; import javax.xml.ws.WebServiceException; import javax.xml.ws.WebServiceFeature; import javax.xml.ws.handler.HandlerResolver; import javax.xml.ws.soap.SOAPBinding; import javax.xml.ws.spi.Provider; import javax.xml.ws.spi.ServiceDelegate; import javax.xml.ws.wsaddressing.W3CEndpointReference; import java.io.BufferedReader; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.URL; import java.util.ArrayList; import java.util.Collections; import java.util.Enumeration; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Properties; import java.util.concurrent.Executor; import java.util.logging.Level; import java.util.logging.Logger; public class JaxWsProviderWrapper extends Provider { private static final Logger logger = Logger.getLogger("OpenEJB.client"); // // Magic to get our proider wrapper installed with the PortRefData // private static final ThreadLocal<ProviderWrapperData> threadPortRefs = new ThreadLocal<ProviderWrapperData>(); private static final String JAXWSPROVIDER_PROPERTY = Provider.class.getName(); public static void beforeCreate(final List<PortRefMetaData> portRefMetaDatas) { // Axis JAXWS api is non compliant and checks system property before classloader // so we replace system property so this wrapper is selected. The original value // is saved into an openejb property so we can load the class in the find method final String oldProperty = System.getProperty(JAXWSPROVIDER_PROPERTY); if (oldProperty != null && !oldProperty.equals(JaxWsProviderWrapper.class.getName())) { System.setProperty("openejb." + JAXWSPROVIDER_PROPERTY, oldProperty); } if (oldProperty == null || !oldProperty.equals(JaxWsProviderWrapper.class.getName())) { System.setProperty(JAXWSPROVIDER_PROPERTY, JaxWsProviderWrapper.class.getName()); } final ClassLoader oldClassLoader = Thread.currentThread().getContextClassLoader(); if (oldClassLoader != null) { Thread.currentThread().setContextClassLoader(new ProviderClassLoader(oldClassLoader)); } else { Thread.currentThread().setContextClassLoader(new ProviderClassLoader()); } threadPortRefs.set(new ProviderWrapperData(portRefMetaDatas, oldClassLoader)); } public static void afterCreate() { Thread.currentThread().setContextClassLoader(threadPortRefs.get().callerClassLoader); threadPortRefs.set(null); } private static class ProviderWrapperData { private final List<PortRefMetaData> portRefMetaData; private final ClassLoader callerClassLoader; public ProviderWrapperData(final List<PortRefMetaData> portRefMetaDatas, final ClassLoader callerClassLoader) { this.portRefMetaData = portRefMetaDatas; this.callerClassLoader = callerClassLoader; } } // // Provider wappre implementation // private final Provider delegate; private final List<PortRefMetaData> portRefs; public JaxWsProviderWrapper() { delegate = findProvider(); portRefs = threadPortRefs.get().portRefMetaData; } public Provider getDelegate() { return delegate; } @SuppressWarnings("unchecked") @Override public ServiceDelegate createServiceDelegate(final URL wsdlDocumentLocation, final QName serviceName, final Class serviceClass) { ServiceDelegate serviceDelegate = delegate.createServiceDelegate(wsdlDocumentLocation, serviceName, serviceClass); serviceDelegate = new ServiceDelegateWrapper(serviceDelegate); return serviceDelegate; } @Override public Endpoint createEndpoint(final String bindingId, final Object implementor) { return delegate.createEndpoint(bindingId, implementor); } @Override public Endpoint createAndPublishEndpoint(final String address, final Object implementor) { return delegate.createAndPublishEndpoint(address, implementor); } @Override public W3CEndpointReference createW3CEndpointReference(final String address, final QName serviceName, final QName portName, final List<Element> metadata, final String wsdlDocumentLocation, final List<Element> referenceParameters) { return (W3CEndpointReference) invoke21Delegate(delegate, createW3CEndpointReference, address, serviceName, portName, metadata, wsdlDocumentLocation, referenceParameters); } @Override public EndpointReference readEndpointReference(final Source source) { return (EndpointReference) invoke21Delegate(delegate, readEndpointReference, source); } @Override @SuppressWarnings({"unchecked"}) public <T> T getPort(final EndpointReference endpointReference, final Class<T> serviceEndpointInterface, final WebServiceFeature... features) { return (T) invoke21Delegate(delegate, providerGetPort, endpointReference, serviceEndpointInterface, features); } private class ServiceDelegateWrapper extends ServiceDelegate { private final ServiceDelegate serviceDelegate; public ServiceDelegateWrapper(final ServiceDelegate serviceDelegate) { this.serviceDelegate = serviceDelegate; } @Override public <T> T getPort(final QName portName, final Class<T> serviceEndpointInterface) { final T t = serviceDelegate.getPort(portName, serviceEndpointInterface); setProperties((BindingProvider) t, portName); return t; } @Override public <T> T getPort(final Class<T> serviceEndpointInterface) { final T t = serviceDelegate.getPort(serviceEndpointInterface); QName qname = null; if (serviceEndpointInterface.isAnnotationPresent(WebService.class)) { final WebService webService = serviceEndpointInterface.getAnnotation(WebService.class); final String targetNamespace = webService.targetNamespace(); final String name = webService.name(); if (targetNamespace != null && targetNamespace.length() > 0 && name != null && name.length() > 0) { qname = new QName(targetNamespace, name); } } setProperties((BindingProvider) t, qname); return t; } @Override public void addPort(final QName portName, final String bindingId, final String endpointAddress) { serviceDelegate.addPort(portName, bindingId, endpointAddress); } @Override public <T> Dispatch<T> createDispatch(final QName portName, final Class<T> type, final Service.Mode mode) { final Dispatch<T> dispatch = serviceDelegate.createDispatch(portName, type, mode); setProperties(dispatch, portName); return dispatch; } @Override public Dispatch<Object> createDispatch(final QName portName, final JAXBContext context, final Service.Mode mode) { final Dispatch<Object> dispatch = serviceDelegate.createDispatch(portName, context, mode); setProperties(dispatch, portName); return dispatch; } @Override @SuppressWarnings({"unchecked"}) public <T> Dispatch<T> createDispatch(final QName portName, final Class<T> type, final Service.Mode mode, final WebServiceFeature... features) { return (Dispatch<T>) invoke21Delegate(serviceDelegate, createDispatchInterface, portName, type, mode, features); } @Override @SuppressWarnings({"unchecked"}) public Dispatch<java.lang.Object> createDispatch(final QName portName, final JAXBContext context, final Service.Mode mode, final WebServiceFeature... features) { return (Dispatch<Object>) invoke21Delegate(serviceDelegate, createDispatchJaxBContext, portName, context, mode, features); } @Override @SuppressWarnings({"unchecked"}) public Dispatch<Object> createDispatch( final EndpointReference endpointReference, final JAXBContext context, final Service.Mode mode, final WebServiceFeature... features) { return (Dispatch<Object>) invoke21Delegate(serviceDelegate, createDispatchReferenceJaxB, endpointReference, context, mode, features); } @Override @SuppressWarnings({"unchecked"}) public <T> Dispatch<T> createDispatch(final EndpointReference endpointReference, final java.lang.Class<T> type, final Service.Mode mode, final WebServiceFeature... features) { return (Dispatch<T>) invoke21Delegate(serviceDelegate, createDispatchReferenceClass, endpointReference, type, mode, features); } @Override @SuppressWarnings({"unchecked"}) public <T> T getPort(final QName portName, final Class<T> serviceEndpointInterface, final WebServiceFeature... features) { return (T) invoke21Delegate(serviceDelegate, serviceGetPortByQName, portName, serviceEndpointInterface, features); } @Override @SuppressWarnings({"unchecked"}) public <T> T getPort(final EndpointReference endpointReference, final Class<T> serviceEndpointInterface, final WebServiceFeature... features) { return (T) invoke21Delegate(serviceDelegate, serviceGetPortByEndpointReference, endpointReference, serviceEndpointInterface, features); } @Override @SuppressWarnings({"unchecked"}) public <T> T getPort(final Class<T> serviceEndpointInterface, final WebServiceFeature... features) { return (T) invoke21Delegate(serviceDelegate, serviceGetPortByInterface, serviceEndpointInterface, features); } @Override public QName getServiceName() { return serviceDelegate.getServiceName(); } @Override public Iterator<QName> getPorts() { return serviceDelegate.getPorts(); } @Override public URL getWSDLDocumentLocation() { return serviceDelegate.getWSDLDocumentLocation(); } @Override public HandlerResolver getHandlerResolver() { return serviceDelegate.getHandlerResolver(); } @Override public void setHandlerResolver(final HandlerResolver handlerResolver) { serviceDelegate.setHandlerResolver(handlerResolver); } @Override public Executor getExecutor() { return serviceDelegate.getExecutor(); } @Override public void setExecutor(final Executor executor) { serviceDelegate.setExecutor(executor); } private void setProperties(final BindingProvider proxy, final QName qname) { for (final PortRefMetaData portRef : portRefs) { Class<?> intf = null; if (portRef.getServiceEndpointInterface() != null) { try { intf = proxy.getClass().getClassLoader().loadClass(portRef.getServiceEndpointInterface()); } catch (ClassNotFoundException e) { logger.log(Level.INFO, "Not loading: " + portRef.getServiceEndpointInterface()); } } if ((qname != null && qname.equals(portRef.getQName())) || (intf != null && intf.isInstance(proxy))) { // set address if (!portRef.getAddresses().isEmpty()) { proxy.getRequestContext().put(BindingProvider.ENDPOINT_ADDRESS_PROPERTY, portRef.getAddresses().get(0)); } // set mtom final boolean enableMTOM = portRef.isEnableMtom(); if (enableMTOM && proxy.getBinding() instanceof SOAPBinding) { ((SOAPBinding) proxy.getBinding()).setMTOMEnabled(enableMTOM); } // set properties for (final Map.Entry<Object, Object> entry : portRef.getProperties().entrySet()) { final String name = (String) entry.getKey(); final String value = (String) entry.getValue(); proxy.getRequestContext().put(name, value); } return; } } } } private static Provider findProvider() { ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); if (classLoader == null) { classLoader = ClassLoader.getSystemClassLoader(); } // 0. System.getProperty("openejb.javax.xml.ws.spi.Provider") // This is so those using old axis rules still work as expected String providerClass = System.getProperty("openejb." + JAXWSPROVIDER_PROPERTY); Provider provider = createProviderInstance(providerClass, classLoader); if (provider != null) { return provider; } // 1. META-INF/services/javax.xml.ws.spi.Provider try { for (final URL url : Collections.list(classLoader.getResources("META-INF/services/" + JAXWSPROVIDER_PROPERTY))) { BufferedReader in = null; try { in = new BufferedReader(new InputStreamReader(url.openStream())); providerClass = in.readLine(); provider = createProviderInstance(providerClass, classLoader); if (provider != null) { return provider; } } catch (Exception ignored) { } finally { if (in != null) { try { in.close(); } catch (Throwable e) { //ignore } } } } } catch (Exception ingored) { logger.log(Level.INFO, "No META-INF/services/javax.xml.ws.spi.Provider found"); } // 2. $java.home/lib/jaxws.properties final String javaHome = System.getProperty("java.home"); final File jaxrpcPropertiesFile = new File(new File(javaHome, "lib"), "jaxrpc.properties"); if (jaxrpcPropertiesFile.exists()) { InputStream in = null; try { in = new FileInputStream(jaxrpcPropertiesFile); final Properties properties = new Properties(); properties.load(in); providerClass = properties.getProperty(JAXWSPROVIDER_PROPERTY); provider = createProviderInstance(providerClass, classLoader); if (provider != null) { return provider; } } catch (Exception ignored) { } finally { if (in != null) { try { in.close(); } catch (Throwable e) { //Ignore } } } } // 3. System.getProperty("javax.xml.ws.spi.Provider") providerClass = System.getProperty(JAXWSPROVIDER_PROPERTY); provider = createProviderInstance(providerClass, classLoader); if (provider != null) { return provider; } // 4. Use javax.xml.ws.spi.Provider default try { // disable the OpenEJB JaxWS provider if (classLoader instanceof ProviderClassLoader) { ((ProviderClassLoader) classLoader).enabled = false; } System.getProperties().remove(JAXWSPROVIDER_PROPERTY); provider = Provider.provider(); if (provider != null && !provider.getClass().getName().equals(JaxWsProviderWrapper.class.getName())) { return provider; } } finally { // reenable the OpenEJB JaxWS provider System.setProperty(JAXWSPROVIDER_PROPERTY, providerClass); if (classLoader instanceof ProviderClassLoader) { ((ProviderClassLoader) classLoader).enabled = true; } } throw new WebServiceException("No " + JAXWSPROVIDER_PROPERTY + " implementation found"); } private static Provider createProviderInstance(final String providerClass, final ClassLoader classLoader) { if (providerClass != null && providerClass.length() > 0 && !providerClass.equals(JaxWsProviderWrapper.class.getName())) { try { final Class<? extends Provider> clazz = classLoader.loadClass(providerClass).asSubclass(Provider.class); return clazz.newInstance(); } catch (Throwable e) { logger.log(Level.WARNING, "Unable to construct provider implementation " + providerClass, e); } } return null; } private static class ProviderClassLoader extends ClassLoader { private static final String PROVIDER_RESOURCE = "META-INF/services/" + JAXWSPROVIDER_PROPERTY; private static final URL PROVIDER_URL; static { File tempFile = null; try { try { tempFile = File.createTempFile("openejb-jaxws-provider", "tmp"); } catch (Throwable e) { final File dir = new File("tmp"); if (!dir.exists() && !dir.mkdirs()) { throw new Exception("Failed to create: " + dir.getAbsolutePath()); } tempFile = File.createTempFile("openejb-jaxws-provider", "tmp", dir); } tempFile.deleteOnExit(); OutputStream out = null; try { out = new FileOutputStream(tempFile); out.write(JaxWsProviderWrapper.class.getName().getBytes()); } finally { if (null != out) { try { out.close(); } catch (Throwable e) { //Ignore } } } out.close(); PROVIDER_URL = tempFile.toURI().toURL(); } catch (Throwable e) { throw new ClientRuntimeException("Failed to create openejb-jaxws-provider file: " + tempFile, e); } } public boolean enabled = true; public ProviderClassLoader() { } public ProviderClassLoader(final ClassLoader parent) { super(parent); } @Override public Enumeration<URL> getResources(final String name) throws IOException { Enumeration<URL> resources = super.getResources(name); if (enabled && PROVIDER_RESOURCE.equals(name)) { final ArrayList<URL> list = new ArrayList<URL>(); list.add(PROVIDER_URL); list.addAll(Collections.list(resources)); resources = Collections.enumeration(list); } return resources; } @Override public URL getResource(final String name) { if (enabled && PROVIDER_RESOURCE.equals(name)) { return PROVIDER_URL; } return super.getResource(name); } } // // Delegate methods for JaxWS 2.1 // private static Object invoke21Delegate(final Object delegate, final Method method, final Object... args) { if (method == null) { throw new UnsupportedOperationException("JaxWS 2.1 APIs are not supported"); } try { return method.invoke(delegate, args); } catch (IllegalAccessException e) { throw new WebServiceException(e); } catch (InvocationTargetException e) { if (e.getCause() != null) { throw new WebServiceException(e.getCause()); } throw new WebServiceException(e); } } // Provider methods private static final Method createW3CEndpointReference; private static final Method providerGetPort; private static final Method readEndpointReference; // ServiceDelegate methods private static final Method createDispatchReferenceJaxB; private static final Method createDispatchReferenceClass; private static final Method createDispatchInterface; private static final Method createDispatchJaxBContext; private static final Method serviceGetPortByEndpointReference; private static final Method serviceGetPortByQName; private static final Method serviceGetPortByInterface; static { Method method = null; try { method = Provider.class.getMethod("createW3CEndpointReference", String.class, QName.class, QName.class, List.class, String.class, List.class); } catch (NoSuchMethodException e) { //Ignore } createW3CEndpointReference = method; method = null; try { method = Provider.class.getMethod("getPort", EndpointReference.class, Class.class, WebServiceFeature[].class); } catch (NoSuchMethodException e) { //Ignore } providerGetPort = method; method = null; try { method = Provider.class.getMethod("readEndpointReference", Source.class); } catch (NoSuchMethodException e) { //Ignore } readEndpointReference = method; method = null; try { method = ServiceDelegate.class.getMethod("createDispatch", EndpointReference.class, JAXBContext.class, Service.Mode.class, WebServiceFeature[].class); } catch (NoSuchMethodException e) { //Ignore } createDispatchReferenceJaxB = method; method = null; try { method = ServiceDelegate.class.getMethod("createDispatch", EndpointReference.class, Class.class, Service.Mode.class, WebServiceFeature[].class); } catch (NoSuchMethodException e) { //Ignore } createDispatchReferenceClass = method; method = null; try { method = ServiceDelegate.class.getMethod("createDispatch", QName.class, JAXBContext.class, Service.Mode.class, WebServiceFeature[].class); } catch (NoSuchMethodException e) { //Ignore } createDispatchJaxBContext = method; method = null; try { method = ServiceDelegate.class.getMethod("createDispatch", QName.class, Class.class, Service.Mode.class, WebServiceFeature[].class); } catch (NoSuchMethodException e) { //Ignore } createDispatchInterface = method; method = null; try { method = ServiceDelegate.class.getMethod("getPort", EndpointReference.class, Class.class, WebServiceFeature[].class); } catch (NoSuchMethodException e) { //Ignore } serviceGetPortByEndpointReference = method; method = null; try { method = ServiceDelegate.class.getMethod("getPort", QName.class, Class.class, WebServiceFeature[].class); } catch (NoSuchMethodException e) { //Ignore } serviceGetPortByQName = method; method = null; try { method = ServiceDelegate.class.getMethod("getPort", Class.class, WebServiceFeature[].class); } catch (NoSuchMethodException e) { //Ignore } serviceGetPortByInterface = method; } }